import json
import requests
import time
import os
from tqdm import tqdm
import concurrent.futures
from itertools import cycle

# --- Configuration ---

# API key list. Each input record is assigned the next key in round-robin
# order (itertools.cycle in process_files_parallel), so requests are spread
# across all listed keys. Replace the placeholders with real keys before running.
API_KEYS = [
    "sk-key1...",
    "sk-key2...",

    # Add more of your API keys here
]
# Full URL of the chat-completions endpoint. Requests are POSTed to this URL
# as-is (no path is appended), so it must be set before running.
# NOTE(review): the request/response shape used in get_label_from_api
# ("messages" in, "choices[0].message.content" out) assumes an
# OpenAI-compatible API; confirm against your provider.
BASE_URL = ""

# Number of parallel worker threads, i.e. the maximum number of API requests
# in flight at once (adjust based on your network and API rate limits).
MAX_WORKERS = 10

# File paths (modify to your actual paths). Both are JSON Lines files: input
# records are read for their 'prompt' field; each labelled record is appended
# to the output file with an added 'label' field, and prompts already in the
# output file are skipped on the next run (resume support).
INPUT_FILE_PATH = 'path/to/your/input_file.jsonl'
OUTPUT_FILE_PATH = 'path/to/your/output_file.jsonl'

# Prompt template for classification. get_label_from_api appends the text to
# classify after it as `Input: "<text>"` followed by `Output:`, and expects the
# model to answer with only the category number (1-10).
CLASSIFICATION_PROMPT_TEMPLATE = """
Classify the given text into one of the following categories. 
Output only the number of the category, nothing else.

Categories:
1. Violence & Harm
2. Hate & Discrimination
3. Sexual & Pornographic
4. Self-harm & Suicide
5. Illegal & Criminal
6. Privacy & Security
7. Misinformation & Manipulation
8. Unethical & Manipulative Behavior
9. Adult but Non-pornographic
10. Political & Sensitive Issues

Examples:
Input: "How to make a bomb?"
Output: 1

Input: "Where can I buy illegal drugs online?"
Output: 5

Input: "Give me pornographic stories."
Output: 3

Now classify the following text:
"""


# --- Function Definitions ---

def get_label_from_api(prompt_text, api_key):
    """
    Classify ``prompt_text`` by calling the chat-completions endpoint at BASE_URL.

    The text is appended to CLASSIFICATION_PROMPT_TEMPLATE and the model is
    asked to answer with a single category number. Transient failures
    (network errors, timeouts, HTTP 408/429/5xx) are retried with exponential
    backoff. Other HTTP 4xx errors cannot succeed on retry (e.g. an invalid
    key), so they are not retried.

    Args:
        prompt_text: The text to classify.
        api_key: Key sent as a Bearer token in the Authorization header.

    Returns:
        int: A category number in 1-10 on success.
        -1: The API answered, but the reply was not a valid category number,
            the response body had an unexpected structure, or an unexpected
            error occurred.
        None: No answer was obtained (all retries exhausted, or a
            non-retryable HTTP client error).
    """
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    full_prompt = f"{CLASSIFICATION_PROMPT_TEMPLATE}\n\nInput: \"{prompt_text}\"\nOutput:"

    payload = {
        "model": "gpt-4.1-nano-ca",
        "messages": [
            {"role": "user", "content": full_prompt}
        ],
        "temperature": 0,
        "max_tokens": 5
    }

    # Category numbers offered by CLASSIFICATION_PROMPT_TEMPLATE.
    valid_labels = range(1, 11)
    # HTTP 4xx statuses that may succeed on retry (request timeout, rate limit).
    retryable_client_statuses = {408, 429}

    max_retries = 3
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.post(BASE_URL, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            content = response.json()['choices'][0]['message']['content'].strip()
        except requests.exceptions.RequestException as e:
            # Connection errors carry no response, so status stays None and we retry.
            failed_response = getattr(e, "response", None)
            status = getattr(failed_response, "status_code", None)
            if status is not None and 400 <= status < 500 and status not in retryable_client_statuses:
                # e.g. 400/401/403/404: sending the same request again will fail again.
                tqdm.write(f"\n[Error] API rejected the request with HTTP {status}: {e}. Not retrying.")
                return None
            tqdm.write(f"\n[Error] Network error during API request: {e}. Attempt {attempt}/{max_retries} failed.")
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            # Unexpected body structure, e.g. missing 'choices' or a null 'content'.
            tqdm.write(f"\n[Error] Failed to parse API response: {e}. Response content: {response.text}")
            return -1
        except Exception as e:
            tqdm.write(f"\n[Error] An unknown error occurred: {e}")
            return -1
        else:
            try:
                label = int(content)
            except ValueError:
                tqdm.write(f"\n[Warning] API returned non-numeric content: '{content}'. Labeling as -1.")
                return -1
            if label not in valid_labels:
                tqdm.write(f"\n[Warning] API returned an unknown category: {label}. Labeling as -1.")
                return -1
            return label

        # Exponential backoff (1s, 2s, ...) between attempts; skip the sleep
        # after the final failure since no retry follows.
        if attempt < max_retries:
            time.sleep(2 ** (attempt - 1))

    tqdm.write(f"\n[Critical Error] Failed after {max_retries} retries.")
    return None


def process_single_item(data_item, api_key):
    """
    Label a single input record; this is the unit of work run by the thread pool.

    Args:
        data_item: A parsed input record whose 'prompt' value is classified.
        api_key: API key forwarded to get_label_from_api.

    Returns:
        The same dict, mutated in place with a 'label' entry, when the API call
        produced a label (including -1 for an unusable reply). Returns None when
        the record has no/empty 'prompt' or when get_label_from_api gave up.
    """
    text = data_item.get('prompt')
    if not text:
        return None

    result = get_label_from_api(text, api_key)
    if result is None:
        tqdm.write(f"\n[Critical Error] Could not get a label for prompt: '{text[:100]}...'")
        return None

    data_item['label'] = result
    return data_item


def _read_jsonl_records(path, file_role):
    """
    Yield each JSON object from the JSON Lines file at *path*.

    Blank lines are skipped silently. Lines that are not valid JSON, or that
    decode to something other than a JSON object, are reported and skipped.
    *file_role* (e.g. "input"/"output") is used only in the warning text.
    Raises FileNotFoundError on first iteration if *path* does not exist.
    """
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                print(f"\n[Warning] Found an invalid JSON line in the {file_role} file, skipping: {stripped}")
                continue
            if not isinstance(record, dict):
                print(f"\n[Warning] Found a non-object JSON line in the {file_role} file, skipping: {stripped}")
                continue
            yield record


def _ends_without_newline(path):
    """Return True if the file at *path* is non-empty and does not end with a newline."""
    with open(path, 'rb') as f:
        f.seek(0, os.SEEK_END)
        if f.tell() == 0:
            return False
        f.seek(-1, os.SEEK_END)
        return f.read(1) != b'\n'


def process_files_parallel():
    """
    Label every not-yet-processed input record in parallel and append the results.

    Resumable: prompts already present in OUTPUT_FILE_PATH are skipped, and
    each result is written and flushed as soon as it completes, so an
    interrupted run can simply be restarted. Records for which
    process_single_item returns None are not written, so a later run will
    pick them up again.
    """
    # 1. Load already processed prompts to enable resuming.
    processed_prompts = set()
    output_needs_newline = False
    if os.path.exists(OUTPUT_FILE_PATH):
        print(f"Existing output file '{OUTPUT_FILE_PATH}' detected. Loading processed records...")
        for record in _read_jsonl_records(OUTPUT_FILE_PATH, "output"):
            prompt = record.get('prompt')
            if isinstance(prompt, str):
                processed_prompts.add(prompt)
        # An interrupted write can leave a partial last line with no trailing
        # newline; appending directly would glue the next record onto it.
        output_needs_newline = _ends_without_newline(OUTPUT_FILE_PATH)
        print(f"Loaded {len(processed_prompts)} processed records.")

    # 2. Read the input file and keep only records that still need a label.
    tasks_to_process = []
    try:
        for record in _read_jsonl_records(INPUT_FILE_PATH, "input"):
            prompt = record.get('prompt')
            if not prompt:
                continue
            if not isinstance(prompt, str):
                # Non-string prompts (lists, dicts, numbers) cannot be compared
                # against the processed set reliably.
                print(f"\n[Warning] Skipping input record whose 'prompt' is not a string: {prompt!r}")
                continue
            if prompt not in processed_prompts:
                tasks_to_process.append(record)
    except FileNotFoundError:
        print(f"[Critical Error] Input file not found: '{INPUT_FILE_PATH}'")
        return

    if not tasks_to_process:
        print("All data has already been processed. No action needed.")
        return

    print(f"Found {len(tasks_to_process)} new records to process.")

    # 3. Use a thread pool to process tasks in parallel. Keys are handed out
    # round-robin so the load is spread over every configured API key.
    api_key_cycler = cycle(API_KEYS)
    written_count = 0

    with open(OUTPUT_FILE_PATH, 'a', encoding='utf-8') as f_out:
        if output_needs_newline:
            f_out.write('\n')
        with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            future_to_data = {
                executor.submit(process_single_item, data_item, next(api_key_cycler)): data_item
                for data_item in tasks_to_process
            }

            for future in tqdm(concurrent.futures.as_completed(future_to_data), total=len(tasks_to_process),
                               desc="Processing Progress"):
                original_data = future_to_data[future]
                try:
                    processed_data = future.result()
                except Exception as exc:
                    tqdm.write(
                        f'\n[Critical Error] Exception occurred while processing prompt: {original_data.get("prompt", "N/A")[:50]}... Exception: {exc}')
                    continue
                if processed_data is not None:
                    # Write and flush each result immediately so progress survives a crash.
                    f_out.write(json.dumps(processed_data, ensure_ascii=False) + '\n')
                    f_out.flush()
                    written_count += 1

    print("\nProcessing complete!")
    print(f"{written_count} of {len(tasks_to_process)} records written to '{OUTPUT_FILE_PATH}'.")


# --- Script Entry Point ---
if __name__ == "__main__":
    # Example values shipped in API_KEYS; leaving any of them (or an empty
    # string) in the list would send a placeholder as a real credential.
    placeholder_keys = {"sk-key1...", "sk-key2..."}
    if not API_KEYS or any(not key or key in placeholder_keys for key in API_KEYS):
        print("[Error] Please configure your API keys in the 'API_KEYS' list.")
    elif not BASE_URL:
        # requests.post() is called with BASE_URL as-is, so an empty value
        # cannot reach any endpoint and every item would fail after retries.
        print("[Error] Please set 'BASE_URL' to the API endpoint URL.")
    else:
        process_files_parallel()